Skip to content

[TRTLLM-12339][feat] enable TRTLLM cross attention backend#15345

Merged
cascade812 merged 2 commits into
NVIDIA:mainfrom
cascade812:codex/split-attention-op-trtllm
Jun 16, 2026
Merged

[TRTLLM-12339][feat] enable TRTLLM cross attention backend#15345
cascade812 merged 2 commits into
NVIDIA:mainfrom
cascade812:codex/split-attention-op-trtllm

Conversation

@cascade812

@cascade812 cascade812 commented Jun 14, 2026

Copy link
Copy Markdown
Collaborator

Description

Split out the attention operator and TRTLLM attention backend changes from #13919 to reduce frequent conflicts with main and make CI validation easier for this smaller, self-contained scope.

This PR intentionally keeps the change self-contained:

  • wires thop.attention and its nanobind signature for cross-attention and relative-attention-bias inputs
  • enables the TRTLLM backend path for cross-attention metadata, including Q padding, cross K/V forwarding, and beam-width handling
  • makes trtllm-gen decline cross-attention so cross requests use the THOP path
  • adds only the small backend forward-args fields required by the TRTLLM backend

No module, executor, model, or LLM API caller changes are included in this split.

Summary by CodeRabbit

  • New Features
    • Added cross-attention support with optional cross-key-value tensor inputs.
    • Added optional relative attention bias with configurable maximum distance parameter.

Signed-off-by: Guiju Zhang <guijuz@nvidia.com>
@cascade812 cascade812 force-pushed the codex/split-attention-op-trtllm branch from 7537f51 to 46ef1af Compare June 14, 2026 04:03
@cascade812

Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #54082 [ run ] triggered by Bot. Commit: 46ef1af Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #54082 [ run ] completed with state SUCCESS. Commit: 46ef1af
/LLM/main/L0_MergeRequest_PR pipeline #43166 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@cascade812

Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #54328 [ run ] triggered by Bot. Commit: 3dfaeac Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #54328 [ run ] completed with state FAILURE. Commit: 3dfaeac
/LLM/main/L0_MergeRequest_PR pipeline #43399 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@cascade812

Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #54380 [ run ] triggered by Bot. Commit: 3dfaeac Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #54380 [ run ] completed with state SUCCESS. Commit: 3dfaeac
/LLM/main/L0_MergeRequest_PR pipeline #43450 completed with status: 'SUCCESS'

CI Report

Link to invocation

@cascade812 cascade812 marked this pull request as ready for review June 16, 2026 01:15
@cascade812 cascade812 requested a review from a team as a code owner June 16, 2026 01:15
@cascade812 cascade812 requested a review from QiJune June 16, 2026 01:15

@brb-nv brb-nv left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@cascade812 cascade812 merged commit 2ef2ea5 into NVIDIA:main Jun 16, 2026
11 checks passed
@cascade812 cascade812 deleted the codex/split-attention-op-trtllm branch June 16, 2026 01:27
@coderabbitai

coderabbitai Bot commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Caution

Review failed

Pull request was closed or merged during review

📝 Walkthrough

Walkthrough

This PR adds cross-attention and relative position bias support end-to-end: AttentionForwardArgs gains cross_kv, relative_attention_bias, and relative_attention_max_distance fields; the C++ RunnerBase/Runner and top-level attention() are extended to validate and wire these inputs; TrtllmAttention._run prepares the cross-KV tensor layout; trtllm-gen is guarded to reject cross-attention.

Changes

Cross-attention and Relative Position Bias

Layer / File(s) Summary
Public API contracts
cpp/tensorrt_llm/thop/attentionOp.h, tensorrt_llm/_torch/attention_backend/interface.py, cpp/tensorrt_llm/nanobind/thop/bindings.cpp
torch_ext::attention header gains is_cross, cross_kv, relative_attention_bias, relative_attention_max_distance with defaults; AttentionForwardArgs dataclass gains matching optional fields; nanobind binding exposes the same four kwargs.
C++ runner enqueue wiring
cpp/tensorrt_llm/thop/attentionOp.cpp (lines 376–842)
RunnerBase::run virtual interface and Runner::run override add the three new parameters; runner body validates relative bias tensor shape/dtype, extracts pointer and stride into common_enqueue_params, sets encoder_input_lengths for cross attention, and in the context stage wires cross_kv pointer, num_encoder_tokens, and cross_kv_length.
Top-level attention() validation and dispatch
cpp/tensorrt_llm/thop/attentionOp.cpp (lines 1029–1351)
attention() signature extended; validation relaxed for is_cross (non-MLA paths and KV-cache update disabling); op->mCrossAttention set; relative bias validated (2D/3D dims, max-distance, embedding type) and op->mMaxDistance set; both context and generation runner->run call-sites forward the new parameters.
Python TRTLLM backend
tensorrt_llm/_torch/attention_backend/trtllm.py
TrtllmAttentionMetadata.effective_beam_width returns 1 for cross-attention; _run builds cross_kv by flattening/concatenating encoder k/v and normalizes q into fused-QKV layout; non-MLA k/v assertions gated away for cross-attention; thop.attention call-site extended with effective_beam_width, is_cross, cross_kv, and relative attention fields; forward replaces the cross-attention rejection with a three-state layout assertion.
trtllm-gen backend guard
tensorrt_llm/_torch/attention_backend/trtllm_gen.py
FlashInferTrtllmGenAttention.is_supported() returns (False, reason) immediately when meta.is_cross is true.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • brb-nv
  • chang-l
  • zhenhuaw-me
  • yuxianq
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 71.43% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly identifies the main change: enabling TRTLLM cross-attention backend, with appropriate JIRA reference and feature type marker.
Description check ✅ Passed The PR description provides clear context, explains the rationale for the split, and articulates the specific scope of changes covered by the PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants